import bluepysnap
from scipy.special import expit
import scipy as sc
from scipy.optimize import curve_fit
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt


def get_spindle_amp(
    sim,
    filename=None,
    *,
    time_binsize=5,
    first_window_start=3500,
    window_length=1000,
    n_windows=16,
    peak_height=2,
    peak_width=3,
):
    """Measure spindle amplitude, duration and frequency from Rt_RC spikes.

    The spikes of the ``Rt_RC`` cells in ``thalamus_neurons`` are binned into a
    population firing-rate trace. The trace is then cut into ``n_windows``
    consecutive windows and peaks are detected in each window with
    ``scipy.signal.find_peaks``.

    Spike times are assumed to be in milliseconds (the ``0.001`` factor below
    converts the bin size to seconds) -- TODO confirm against the report.

    Parameters
    ----------
    sim
        Simulation object exposing
        ``sim.spikes["thalamus_neurons"].spike_report.filter(...).report``.
    filename
        If not ``None``, the firing-rate trace is plotted and saved there.
    time_binsize
        Histogram bin width, in the same unit as the spike times.
    first_window_start
        Start time of the first analysis window.
    window_length
        Length of each analysis window.
    n_windows
        Number of consecutive analysis windows.
    peak_height, peak_width
        Passed to ``find_peaks`` as ``height`` and ``width``.

    Returns
    -------
    dict
        Keys ``"amps"``, ``"durations"`` and ``"freqs"``, each a list with one
        entry per window. A window with fewer than three detected peaks
        contributes ``0`` to all three lists.
    """
    from scipy.signal import find_peaks

    report = sim.spikes["thalamus_neurons"].spike_report.filter(group={"mtype": "Rt_RC"}).report
    times = report.index

    # No spikes at all: np.min/np.max would raise on an empty index and the
    # node count would be zero, so report "no spindle" for every window.
    if len(times) == 0:
        return {key: [0] * n_windows for key in ("amps", "durations", "freqs")}

    time_start = np.min(times)
    time_stop = np.max(times)
    # Append time_stop so the last (possibly shorter) bin reaches the final spike.
    bins = np.append(np.arange(time_start, time_stop, time_binsize), time_stop)
    hist, bin_edges = np.histogram(times, bins=bins)
    left_edges = bin_edges[:-1]

    # NOTE(review): only nodes that appear in the report (i.e. fired at least
    # once) are counted, so the rate is normalised by active cells.
    node_count = report[["ids", "population"]].drop_duplicates().shape[0]
    freq = 1.0 * hist / node_count / (0.001 * time_binsize)

    if filename is not None:
        fig, ax = plt.subplots()
        try:
            ax.plot(left_edges, freq)
            fig.savefig(filename)
        finally:
            # Close explicitly: pyplot keeps every figure alive until closed,
            # and this function is called once per simulation.
            plt.close(fig)

    amps = []
    durations = []
    freqs = []
    for step in range(n_windows):
        low = first_window_start + step * window_length
        high = low + window_length
        # Both ends are inclusive, as in the original analysis.
        mask = (left_edges >= low) & (left_edges <= high)
        f = freq[mask]
        t = left_edges[mask]
        peaks = find_peaks(f, height=peak_height, width=peak_width)[0]
        if peaks.size <= 2:
            # Too few peaks to call this window a spindle.
            amps.append(0)
            durations.append(0)
            freqs.append(0)
            continue
        duration = t[peaks[-1]] - t[peaks[0]]
        amps.append(np.mean(f[peaks]))
        durations.append(duration)
        # (number of peak-to-peak intervals) / duration, scaled by 1000.
        freqs.append((len(peaks) - 1) / duration * 1000)
    return {"amps": amps, "durations": durations, "freqs": freqs}


def fit_func(_x, shift_exp, slope_exp=4, amp_exp=60):
    """Evaluate a logistic curve with a constant offset of 2.

    ``shift_exp`` is the midpoint of the sigmoid, ``slope_exp`` its width and
    ``amp_exp`` the height of the step above the offset.
    """
    scaled_distance = (_x - shift_exp) / slope_exp
    sigmoid = expit(scaled_distance)
    return sigmoid * amp_exp + 2


def main():
    """Scan CT activation for every Iahp variant and plot the spindle statistics.

    For each model, every available simulation is analysed with
    ``get_spindle_amp``; the mean amplitude versus percent of CT activated is
    fitted with ``fit_func`` and the fitted midpoint is finally plotted against
    the Iahp conductance. Four PDF files are written to the working directory,
    plus one firing-rate trace per simulation under ``figures/``.
    """
    base_path = Path("../../../simulations/06_09_Iahp_variants/")
    # (model name, variant number, Iahp conductance)
    models = [
        ("lowest-Iahp", 14, 0.005),
        ("low-Iahp", 15, 0.01),
        ("Spp", 2, 0.015),
        ("med-Iahp", 16, 0.020),
        ("high-Iahp", 17, 0.025),
    ]
    percs = [40, 50, 60, 70, 75, 80, 85, 90, 95, 100]
    shift_low, shift_high = 50, 120

    # savefig does not create missing directories.
    Path("figures").mkdir(parents=True, exist_ok=True)

    fig_all, ax_all = plt.subplots(figsize=(5, 3))
    fig_all_d, ax_all_d = plt.subplots(figsize=(5, 2))
    fig_all_f, ax_all_f = plt.subplots(figsize=(5, 2))
    colors = plt.get_cmap("viridis")(np.linspace(0, 1, len(models)))
    x_fit = np.linspace(10, 105, 1000)

    # One (table index, conductance, fitted shift, shift error) per fitted model.
    fits = []
    for i, (model, var, iahp) in enumerate(models):
        print(model)
        found_percs = []
        means = {"amps": [], "durations": [], "freqs": []}
        stds = {"amps": [], "durations": [], "freqs": []}
        run_name = f"variant{var}_{model}"
        for perc in percs:
            sim_path = (
                base_path
                / run_name
                / f"{run_name}_SUPER_LONG_{perc}percent_CT_up_down"
                / "simulation_config.json"
            )
            if not sim_path.exists():
                continue
            found_percs.append(perc)
            sim = bluepysnap.Simulation(sim_path)
            data = get_spindle_amp(sim, filename=f"figures/spindles_{model}_{perc}.pdf")
            for key in means:
                means[key].append(np.mean(data[key]))
                stds[key].append(np.std(data[key]))

        if not found_percs:
            # curve_fit cannot run without data; skip instead of crashing.
            print(f"no simulation found for {model}, skipping")
            continue

        amps_mean = np.array(means["amps"])
        amps_std = np.array(stds["amps"])
        # A single-entry p0 makes curve_fit optimise shift_exp only; slope_exp
        # and amp_exp keep the defaults of fit_func. The start value is the
        # midpoint of the bounds.
        popt, pcov = curve_fit(
            fit_func,
            found_percs,
            amps_mean,
            p0=[0.5 * (shift_low + shift_high)],
            bounds=([shift_low], [shift_high]),
        )
        print(popt)
        fits.append((i, iahp, popt[0], np.sqrt(np.diag(pcov))[0]))

        ax_all_d.errorbar(found_percs, means["durations"], stds["durations"], c=colors[i], label=model)
        ax_all_f.errorbar(found_percs, means["freqs"], stds["freqs"], c=colors[i], label=model)
        print(amps_std)
        ax_all.errorbar(found_percs, amps_mean, amps_std, c=colors[i], ls="")
        ax_all.scatter(found_percs, amps_mean, label=model, c=colors[i], s=15)
        ax_all.plot(x_fit, fit_func(x_fit, *popt), color=colors[i], lw=0.8)

    # Axis formatting and saving are done once, after all models are drawn.
    # This axis shows the mean peak firing rate ("amps"), not a frequency.
    ax_all.set_ylabel("Mean spindle amplitude (Hz)")
    ax_all.set_xlabel("percent CT activated")
    ax_all.set_xlim(39, 105)
    ax_all.set_ylim(0, 100)
    ax_all.legend()
    fig_all.tight_layout()
    fig_all.savefig("scan_spp.pdf")

    ax_all_d.set_ylabel("Spindle duration [ms]")
    ax_all_d.set_xlabel("percent CT activated")
    ax_all_d.set_xlim(39, 105)
    ax_all_d.set_ylim(0, 1000)
    ax_all_d.legend()
    fig_all_d.tight_layout()
    fig_all_d.savefig("scan_duration_spp.pdf")

    # get_spindle_amp scales the peak rate by 1000, i.e. peaks per second.
    ax_all_f.set_ylabel("Spindle frequency [Hz]")
    ax_all_f.set_xlabel("percent CT activated")
    ax_all_f.set_xlim(39, 105)
    ax_all_f.set_ylim(0, 15)
    ax_all_f.legend()
    fig_all_f.tight_layout()
    fig_all_f.savefig("scan_freq_spp.pdf")

    fig_fit, ax_fit = plt.subplots(figsize=(5, 2))
    # The first model of the table is left out of the linear fit.
    line_points = [(g, shift, err) for idx, g, shift, err in fits if idx > 0]
    if len(line_points) >= 2:
        line_g, line_shift, line_err = (np.array(col) for col in zip(*line_points))
        z = np.polyfit(line_g, line_shift, 1, w=1.0 / line_err)
        print(z)
        x = np.linspace(0.002, 0.027, 10)
        ax_fit.plot(x, np.poly1d(z)(x), c="k", label="linear fit")
    if fits:
        ax_fit.scatter(
            [g for _, g, _, _ in fits],
            [shift for _, _, shift, _ in fits],
            c="k",
            label="fitted turning point",
        )
        ax_fit.legend()
    ax_fit.set_xlabel("iahp conductances")
    ax_fit.set_ylabel("CT turning point")
    ax_fit.axis([0.002, 0.027, 40, 100])
    fig_fit.tight_layout()
    fig_fit.savefig("fit_spp.pdf")


if __name__ == "__main__":
    main()
